前一日已經開始進行模型的訓練。本日將討論要如何確認或挑選訓練出來的模型是否真的好?真的朝著正確的方向在邁進呢?
在訓練的過程當中,很多情況只要是Training Code沒有異常的bugs的情況底下,在訓練集上的loss通常只會不斷下降。
這是否表示,我訓練出來的模型正在不斷變好呢? 這個答案,你知、我知、獨眼龍也知道,當然是不! 主因是在機械學習越來越發展以來,模型內的參數也隨之越來越多。在這種情況底下,模型很有可能會去把某些樣本或是巧合硬是記下來,換句人話就是模型把答案硬背下來了!(尤其是在Deep Learning時代下,參數又多,訓練集的樣本又每張都看過幾十幾百次的狀況下,更加的容易發生。)
上圖則是參考Wiki Overfitting條目當中的圖,其中綠色的線就是想表達一個過擬合的圖。
雖然它完全正確分出紅色藍色了,但我們事實上很害怕這樣的模型,如同上面所說的,它只是用極端的狀況去硬記訓練資料,進而在實際上無法套用到新的或實際的資料上。
為了去驗證我們的模型套用沒有學習過的資料時的效果,一般我們會保留一份資料用來檢查模型表現,這個資料子集通常我們就稱作為驗證集。(另外還有測試集,後續會再介紹。)
常見的具體實作,通常我們每次訓練模型到一個段落的時候,會使用當下的模型針對驗證集內的所有資料進行推論,並紀錄當下的模型在驗證集上各種metric的表現,進而評估模型的好壞。
以我們的Multi-Label Classification來說,我們最主要就是比較準確率(Accuracy)以及AUROC(Area Under the Receiver Operating Characteristic),這部份在Torchmetrics跟MONAI上都可以找到對應的函數可以使用。
本次的具體實作可以在每一個epoch的後面加上:
model.eval()
with torch.no_grad():
y_pred = torch.tensor([], dtype=torch.float32, device=device)
y = torch.tensor([], dtype=torch.long, device=device)
pbar = tqdm.tqdm(data_generators['VALIDATION'], total = len(processed_datasets['VALIDATION']) // data_generators['VALIDATION'].batch_size)
for batch in pbar:
val_images, val_labels = batch['img'].to(device), batch['labels'].to(device)
y_pred = torch.cat([y_pred, model(val_images)], dim=0)
y = torch.cat([y, val_labels], dim=0)
pbar.set_description('Validating ...')
y_prob = torch.nn.Sigmoid()(y_pred)
loss = loss_function(y_pred, y.float()).item()
acc_score = torchmetrics.functional.accuracy(y_prob, y, mdmc_average = 'global').item()
auc_score = monai.metrics.compute_roc_auc(y_prob, y, average='macro').item()
這裡要注意幾個重要的小細節,分別是
model.eval()
:做推論的模式切換,沒有做的話像是Dropout或是Batch Normalization就會根據訓練的模式跑出不正確的結果。with torch.no_grad()
:使用沒有梯度的模式進行運算,節省運算資源。新增了Validation以後的實作一樣放在Github對應的commit內,簡單執行,等待一下就可以得到結果:
# python src/train.py
----------
----------
epoch 24/25
Training Epoch 50/50train_loss: 0.2043: 98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 49/50 [00:31<00:00, 1.57it/s]
epoch 24 average loss: 0.1974
Validating ...: : 88it [00:10, 8.11it/s]
current epoch: 24 current loss : 0.1849 current AUC: 0.5837 current accuracy: 0.9489 best AUC: 0.5849 at epoch: 23
----------
epoch 25/25
Training Epoch 50/50train_loss: 0.2136: 98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 49/50 [00:29<00:00, 1.68it/s]
epoch 25 average loss: 0.1972
Validating ...: : 88it [00:11, 7.84it/s]
saved new best metric model
current epoch: 25 current loss : 0.1843 current AUC: 0.5866 current accuracy: 0.9490 best AUC: 0.5866 at epoch: 25
train completed, best_metric(AUC): 0.5866at epoch: 25
這裡可以注意到幾點分別為: